import pickle
import gzip
import numpy as np

import tensorflow as tf


def stan(data, eps=1e-6):
    """Scale ``data`` by its maximum absolute value along axis 0.

    For a 2-D input each column is divided by the largest absolute value in
    that column. Columns whose maximum absolute value is below ``eps`` (for
    example all-zero columns) are divided by 1 instead, so they are left
    unchanged and no division by zero occurs.

    Args:
        data: array-like of numeric values. The maximum is taken over axis 0;
            a 1-D input is therefore scaled by its single overall maximum.
        eps: threshold below which a maximum is treated as zero. Defaults to
            1e-6, the value previously hard-coded here.

    Returns:
        A new ``np.ndarray`` with the same shape as ``data``. The input is not
        modified.
    """
    data = np.asarray(data)
    # ``initial=0`` keeps the reduction defined for inputs with zero rows.
    # After ``np.abs`` every value is >= 0, so it never changes the result
    # for non-empty inputs.
    max_ = np.max(np.abs(data), axis=0, initial=0)
    # ``np.where`` is used rather than in-place masked assignment because it
    # also works when the reduction yields a scalar (1-D input).
    max_ = np.where(max_ < eps, 1, max_)
    return data / max_


def _load_gcnn_sample(filename):
    """Read one gzip-compressed pickled sample and return its per-sample arrays.

    Returns:
        Tuple ``(c_feats, e_inds, e_feats, v_feats, choice_mask)``.
        - ``c_feats``, ``e_feats`` and ``v_feats`` are scaled with ``stan``.
        - ``v_feats`` is the static variable features with the dynamic
          (updated) variable features appended as extra columns.
        - ``choice_mask`` is an integer 0/1 vector with one entry per
          variable row, set to 1 at the chosen indices.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only load sample files produced by a trusted source.
    with gzip.open(filename, 'rb') as f:
        sample = pickle.load(f)

    sample_state, updated_v_features, sample_choices = sample['data']
    # Only the first three state entries are used. The unpacking still
    # requires exactly eight entries, as before.
    c, e, v, _, _, _, _, _ = sample_state

    # Concatenate the dynamic features as extra columns. The transpose implies
    # updated_v_features is stored as (n_dynamic_features, n_vars) -- TODO
    # confirm against the sample generator.
    variable_features = np.concatenate(
        (v['values'], updated_v_features.transpose(1, 0)), axis=1)

    # An explicit integer dtype keeps the mask integral even when a sample has
    # zero variables. An empty list literal would otherwise become float64.
    choice_mask = np.zeros(variable_features.shape[0], dtype=int)
    choice_mask[sample_choices] = 1

    return (stan(c['values']), e['indices'], stan(e['values']),
            stan(variable_features), choice_mask)


def load_batch_gcnn(sample_files):
    """
    Loads and concatenates a bunch of samples into one mini-batch.

    The samples are merged into one disconnected graph:
    - Constraint, variable and edge feature rows are stacked.
    - Edge indices are shifted so they point into the stacked arrays.

    Args:
        sample_files: iterable of paths to gzip-compressed pickled samples.

    Returns:
        Tuple of tensors ``(c_features, e_indices, e_features, v_features,
        n_cs_per_sample, n_vs_per_sample, cand_choices)``:
        - The feature tensors are float32.
        - ``e_indices``, the per-sample row counts and the concatenated 0/1
          candidate mask ``cand_choices`` are int32.

    Raises:
        ValueError: if ``sample_files`` is empty.
    """
    # load samples
    samples = [_load_gcnn_sample(filename) for filename in sample_files]
    if not samples:
        raise ValueError("load_batch_gcnn requires at least one sample file")
    c_features, e_indices, e_features, v_features, cand_choices = zip(*samples)

    n_cs_per_sample = [c.shape[0] for c in c_features]
    n_vs_per_sample = [v.shape[0] for v in v_features]

    # concatenate samples in one big graph
    c_features = np.concatenate(c_features, axis=0)
    v_features = np.concatenate(v_features, axis=0)
    e_features = np.concatenate(e_features, axis=0)
    # The candidate mask has one entry per variable row, so plain
    # concatenation keeps it aligned with v_features; no shift is needed.
    cand_choices = np.concatenate(cand_choices, axis=0)

    # Edge indices are local to each sample, so they must be shifted.
    # Column j of cv_shift holds the number of constraints (row 0) and
    # variables (row 1) in samples 0..j-1. It is broadcast over sample j's
    # edges, which implies e_ind is (2, n_edges) with constraint indices in
    # row 0 and variable indices in row 1.
    cv_shift = np.cumsum([
            [0] + n_cs_per_sample[:-1],
            [0] + n_vs_per_sample[:-1]
        ], axis=1)
    e_indices = np.concatenate([e_ind + cv_shift[:, j:(j+1)]
        for j, e_ind in enumerate(e_indices)], axis=1)

    # convert to tensors
    c_features = tf.convert_to_tensor(c_features, dtype=tf.float32)
    e_indices = tf.convert_to_tensor(e_indices, dtype=tf.int32)
    e_features = tf.convert_to_tensor(e_features, dtype=tf.float32)
    v_features = tf.convert_to_tensor(v_features, dtype=tf.float32)
    n_cs_per_sample = tf.convert_to_tensor(n_cs_per_sample, dtype=tf.int32)
    n_vs_per_sample = tf.convert_to_tensor(n_vs_per_sample, dtype=tf.int32)
    cand_choices = tf.convert_to_tensor(cand_choices, dtype=tf.int32)

    return c_features, e_indices, e_features, v_features, n_cs_per_sample, n_vs_per_sample, cand_choices
